Explain Blackbox Regressors
In this notebook, we use the interpret package to explain blackbox regressors with SHAP, LIME, MorrisSensitivity, and PartialDependence.
This notebook can be found in our examples folder on GitHub.
# install interpret if not already installed
try:
import interpret
except ModuleNotFoundError:
!pip install --quiet interpret pandas scikit-learn lime
import numpy as np
import pandas as pd
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from interpret import show
from interpret import set_visualize_provider
from interpret.provider import InlineProvider
# One seed drives both NumPy's global RNG and the train/test split, so reruns
# of the notebook reproduce the same split.
seed = 42
np.random.seed(seed)

# Render interpret's visualizations inline in the notebook output.
set_visualize_provider(InlineProvider())

# Load the diabetes regression data as pandas objects (as_frame=True), so the
# feature columns keep their names for the explanations below.
X, y = load_diabetes(return_X_y=True, as_frame=True)

# Hold out 20% of the rows for evaluation.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.20,
    random_state=seed,
)
Train a blackbox regression system
from sklearn.ensemble import RandomForestRegressor
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
# A blackbox system may bundle preprocessing with the regressor: here PCA
# transforms the features before they reach the random forest.
pca = PCA()
rf = RandomForestRegressor(random_state=seed)
blackbox_model = Pipeline(
    steps=[
        ("pca", pca),
        ("rf", rf),
    ]
)
blackbox_model.fit(X_train, y_train)
Pipeline(steps=[('pca', PCA()), ('rf', RandomForestRegressor(random_state=42))])In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
Pipeline(steps=[('pca', PCA()), ('rf', RandomForestRegressor(random_state=42))])PCA()
RandomForestRegressor(random_state=42)
Show blackbox model performance
from interpret.perf import RegressionPerf
# Score the fitted pipeline on the held-out split and render the metrics.
blackbox_perf = RegressionPerf(blackbox_model).explain_perf(
    X_test,
    y_test,
    name="Blackbox",
)
show(blackbox_perf)
Local Explanations: How an individual prediction was made
from interpret.blackbox import LimeTabular
# Blackbox explainers only need something that predicts (here, the whole
# pipeline) and, optionally, a dataset to draw on.
lime = LimeTabular(blackbox_model, X_train, random_state=1)

# Explain the first five held-out rows (positional selection); passing the
# true labels is optional.
lime_local = lime.explain_local(
    X_test.iloc[:5],
    y_test.iloc[:5],
    name="LIME",
)
show(lime_local, 0)  # presumably renders the explanation for the first row
from interpret.blackbox import ShapKernel
# Background data for ShapKernel: a single row holding the per-feature median
# of the training set, labelled with the original feature names.
# NOTE(review): ShapKernel imports the `shap` package. shap builds compiled for
# NumPy 1.x fail under NumPy 2.x with "AttributeError: `np.obj2sctype` was
# removed in the NumPy 2.0 release"; use a NumPy-2-compatible shap
# (believed to be >=0.46 — confirm) or pin numpy<2.
background_val = pd.DataFrame(
    np.median(X_train, axis=0)[np.newaxis, :],
    columns=X.columns,
)
shap = ShapKernel(blackbox_model, background_val)

# Explain the first five held-out rows (positional selection), labels optional.
shap_local = shap.explain_local(
    X_test.iloc[:5],
    y_test.iloc[:5],
    name="SHAP",
)
show(shap_local, 0)  # presumably renders the explanation for the first row
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.
If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.
Traceback (most recent call last): File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
app.launch_new_instance()
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/traitlets/config/application.py", line 1043, in launch_instance
app.start()
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 725, in start
self.io_loop.start()
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 205, in start
self.asyncio_loop.run_forever()
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
self._run_once()
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
handle._run()
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 513, in dispatch_queue
await self.process_one()
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 502, in process_one
await dispatch(*args)
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 409, in dispatch_shell
await result
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
reply_content = await reply_content
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
res = shell.run_cell(
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 540, in run_cell
return super().run_cell(*args, **kwargs)
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2961, in run_cell
result = self._run_cell(
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3016, in _run_cell
result = runner(coro)
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
coro.send(None)
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3221, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3400, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/tmp/ipykernel_3606/2654330808.py", line 4, in <module>
shap = ShapKernel(blackbox_model, background_val)
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/interpret/blackbox/_shap.py", line 32, in __init__
from shap import KernelExplainer
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/__init__.py", line 4, in <module>
from .explainers import other
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/explainers/__init__.py", line 4, in <module>
from ._gpu_tree import GPUTreeExplainer
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/explainers/_gpu_tree.py", line 5, in <module>
from ._tree import (
File "/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/explainers/_tree.py", line 29, in <module>
from .. import _cext
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
AttributeError: _ARRAY_API not found
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[6], line 4
1 from interpret.blackbox import ShapKernel
3 background_val = pd.DataFrame(np.median(X_train, axis=0).reshape(1, -1), columns=X.columns)
----> 4 shap = ShapKernel(blackbox_model, background_val)
5 shap_local = shap.explain_local(X_test[:5], y_test[:5], name='SHAP')
6 show(shap_local, 0)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/interpret/blackbox/_shap.py:32, in ShapKernel.__init__(self, model, data, feature_names, feature_types, **kwargs)
21 def __init__(self, model, data, feature_names=None, feature_types=None, **kwargs):
22 """Initializes class.
23
24 Args:
(...)
29 **kwargs: Kwargs that will be sent to shap.KernelExplainer
30 """
---> 32 from shap import KernelExplainer
34 self.model = model
35 self.feature_names = feature_names
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/__init__.py:45
43 have_matplotlib = False
44 if have_matplotlib:
---> 45 from . import plots
46 from .plots._bar import bar_legacy as bar_plot
47 from .plots._beeswarm import summary_legacy as summary_plot
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/__init__.py:6
3 except ImportError:
4 raise ImportError("matplotlib is not installed so plotting is not available! Run `pip install matplotlib` to fix this.")
----> 6 from ._bar import bar
7 from ._beeswarm import beeswarm
8 from ._benchmark import benchmark
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/_bar.py:9
7 from ..utils import format_value, ordinal_str
8 from ..utils._exceptions import DimensionError
----> 9 from . import colors
10 from ._labels import labels
11 from ._utils import (
12 convert_ordering,
13 dendrogram_coords,
(...)
16 sort_inds,
17 )
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/__init__.py:1
----> 1 from ._colors import (
2 blue_rgb,
3 gray_rgb,
4 light_blue_rgb,
5 light_red_rgb,
6 red_blue,
7 red_blue_circle,
8 red_blue_no_bounds,
9 red_blue_transparent,
10 red_rgb,
11 red_transparent_blue,
12 red_white_blue,
13 transparent_blue,
14 transparent_red,
15 )
17 __all__ = [
18 "blue_rgb",
19 "gray_rgb",
(...)
30 "transparent_red",
31 ]
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/_colors.py:23
21 red_lch = [54., 90., 0.35470565 + 2* np.pi]
22 gray_lch = [55., 0., 0.]
---> 23 blue_rgb = lch2rgb(blue_lch)
24 red_rgb = lch2rgb(red_lch)
25 gray_rgb = lch2rgb(gray_lch)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/_colors.py:13, in lch2rgb(x)
12 def lch2rgb(x):
---> 13 return lab2rgb(lch2lab([[x]]))[0][0]
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/_colorconv.py:372, in lch2lab(lch)
346 def lch2lab(lch):
347 """CIE-LCH to CIE-LAB color space conversion.
348 LCH is the cylindrical representation of the LAB (Cartesian) colorspace
349 Parameters
(...)
370 >>> img_lab2 = lch2lab(img_lch)
371 """
--> 372 lch = _prepare_lab_array(lch)
374 c, h = lch[..., 1], lch[..., 2]
375 lch[..., 1], lch[..., 2] = c * np.cos(h), c * np.sin(h)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/_colorconv.py:387, in _prepare_lab_array(arr)
385 if shape[-1] < 3:
386 raise ValueError('Input array has less than 3 color channels')
--> 387 return img_as_float(arr, force_copy=True)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/_colorconv.py:993, in img_as_float(image, force_copy)
972 def img_as_float(image, force_copy=False):
973 """Convert an image to floating point format.
974 This function is similar to `img_as_float64`, but will not convert
975 lower-precision floating point arrays to `float64`.
(...)
991 and can be outside the ranges [0.0, 1.0] or [-1.0, 1.0].
992 """
--> 993 return convert(image, np.floating, force_copy)
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/shap/plots/colors/_colorconv.py:819, in convert(image, dtype, force_copy, uniform)
808 itemsize_out = dtypeobj_out.itemsize
810 # Below, we do an `issubdtype` check. Its purpose is to find out
811 # whether we can get away without doing any image conversion. This happens
812 # when:
(...)
816 # is a subclass of that type (e.g. `np.floating` will allow
817 # `float32` and `float64` arrays through)
--> 819 if np.issubdtype(dtype_in, np.obj2sctype(dtype)):
820 if force_copy:
821 image = image.copy()
File /opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/numpy/__init__.py:397, in __getattr__(attr)
394 raise AttributeError(__former_attrs__[attr])
396 if attr in __expired_attributes__:
--> 397 raise AttributeError(
398 f"`np.{attr}` was removed in the NumPy 2.0 release. "
399 f"{__expired_attributes__[attr]}"
400 )
402 if attr == "chararray":
403 warnings.warn(
404 "`np.chararray` is deprecated and will be removed from "
405 "the main namespace in the future. Use an array with a string "
406 "or bytes dtype instead.", DeprecationWarning, stacklevel=2)
AttributeError: `np.obj2sctype` was removed in the NumPy 2.0 release. Use `np.dtype(obj).type` instead.
Global Explanations: How the model behaves overall
from interpret.blackbox import MorrisSensitivity
# Build a Morris sensitivity explainer over the training data and render its
# global (whole-model) summary.
sensitivity = MorrisSensitivity(blackbox_model, X_train)
sensitivity_global = sensitivity.explain_global(
    name="Global Sensitivity",
)
show(sensitivity_global)
from interpret.blackbox import PartialDependence
# Build a partial-dependence explainer over the training data and render one
# of its global views.
pdp = PartialDependence(blackbox_model, X_train)
pdp_global = pdp.explain_global(
    name="Partial Dependence",
)
show(pdp_global, 0)  # presumably selects the first feature's plot